# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.

from __future__ import annotations
from functools import lru_cache
from torch import Tensor


from megatron.core.models.common.embeddings.yarn_rotary_pos_embedding import YarnRotaryEmbedding

##HACK(huang.huang): increase lru_cache maxsize from 32 to 64, to prevent cache missed when layer>32 in a pp stage 
@lru_cache(maxsize=64)
def YarnRotaryEmbedding_forward(self, max_seq_len: int, offset: int = 0) -> Tensor:
    return self._orig_forward(max_seq_len, offset)
## HACK(huang.huang)

from transformer_engine.musa.pytorch.utils import replace_attr
replace_attr(YarnRotaryEmbedding,"forward", YarnRotaryEmbedding_forward)